import torch
import torch.nn as nn

def loss(pred, y_true, prior):
    """Unimplemented placeholder loss.

    ``pred``, ``y_true`` and ``prior`` are accepted but ignored; the function
    always returns ``None``.
    """
    return None

def ramp_loss(pred):
    """Ramp loss ``min(max(1 - pred, 0), 2) / 2``, element-wise, valued in [0, 1]."""
    margin_gap = 1 - pred
    return margin_gap.clamp(0, 2) / 2

def squared_loss(pred):
    """Squared loss ``(pred - 1)**2 / 4``, element-wise."""
    deviation = pred - 1
    return deviation.pow(2) / 4

def double_hinge_loss(pred):
    """Double hinge loss ``max(-pred, max(0, 1/2 - pred/2))``, element-wise.

    The zero floor is applied with ``clamp`` instead of comparing against a
    freshly built ``torch.tensor([0.0])``. That constant is a CPU float32
    tensor of shape ``(1,)``, so it failed for non-CPU inputs (device
    mismatch), turned 0-dim inputs into shape ``(1,)`` outputs, and promoted
    half-precision inputs to float32. The result now keeps ``pred``'s device,
    dtype and shape.
    """
    half_hinge = torch.clamp(0.5 - 0.5 * pred, min=0)
    # Element-wise maximum of two tensors (the two-tensor form of torch.max).
    return torch.max(-pred, half_hinge)

def hinge_loss(pred):
    """Hinge loss ``max(0, 1 - pred)``, element-wise."""
    shortfall = 1 - pred
    return shortfall.clamp(min=0)

def sigmoid_loss(pred):
    """Sigmoid loss ``sigmoid(-pred) = 1 / (1 + exp(pred))``, element-wise."""
    return torch.sigmoid(torch.neg(pred))

def zero_one_loss(pred):
    """0-1 loss: 0 where ``pred > 0``, otherwise 1 (including 0 and NaN).

    The result has the same dtype and device as ``pred``.
    """
    is_positive = pred > 0
    return torch.where(is_positive, torch.zeros_like(pred), torch.ones_like(pred))

def logistic_loss(pred):
    """Logistic loss ``log(1 + exp(-pred))``, computed as ``-logsigmoid(pred)``."""
    log_prob = nn.functional.logsigmoid(pred)
    return -1.0 * log_prob

def exp_loss(pred):
    """Exponential loss ``exp(-pred)``, element-wise."""
    return torch.neg(pred).exp()

def unhinged_loss(pred):
    """Unhinged (linear) loss ``1 - pred``; not bounded below."""
    return -pred + 1
    
def multi_class_loss(pred, Y_test, *, loss_fn=None):
    """One-vs-rest surrogate loss for multi-class scores, averaged over samples.

    For each sample (index i), with l = ``loss_fn``:

        loss_i = sum_j Y_ij * l(pred_ij)
                 + 1/(k-1) * sum_j (1 - Y_ij) * l(-pred_ij)

    Args:
        pred: class scores with classes on the LAST axis, e.g. shape ``(n, k)``.
            Extra leading axes are also accepted; all of them are averaged.
        Y_test: label indicator broadcastable against ``pred``. Presumably
            one-hot (TODO confirm with callers); with one-hot labels the
            ``1/(k-1)`` factor averages over the non-target classes.
        loss_fn: element-wise binary surrogate (e.g. any of the ``*_loss``
            helpers in this module). Defaults to ``hinge_loss``.

    Returns:
        0-dim tensor: mean of the per-sample losses.

    Note:
        Requires k >= 2; with a single class ``1.0 / (k - 1)`` raises
        ZeroDivisionError.
    """
    if loss_fn is None:
        loss_fn = hinge_loss
    # Classes are summed over dim=-1 below, so k must come from the same axis;
    # shape[1] was only right for 2-D input.
    k = pred.shape[-1]
    positive_loss_matrix = loss_fn(pred)
    negative_loss_matrix = loss_fn(-pred)
    labeled_loss = (positive_loss_matrix * Y_test).sum(dim=-1)
    unlabeled_loss = (negative_loss_matrix * (1 - Y_test)).sum(dim=-1)
    per_sample = labeled_loss + 1.0 / (k - 1) * unlabeled_loss
    return per_sample.mean()

